import os
import time
import random
import numpy as np
import torch
from loaders import *
from models import *
from reporters import *
from experiments import *


def make_dataset(root_dir, dataset_name, params):
    """Load a dataset from disk and wrap its train/val/test splits in loaders.

    The file ``<root_dir>/data/<dataset_name>.pt`` is loaded with
    ``torch.load`` and its first element is used as the dataset object.

    Args:
        root_dir: project root that contains the ``data`` directory.
        dataset_name: file stem of the ``.pt`` file to load.
        params: configuration mapping. ``val_ratio`` and ``test_ratio`` are
            required. If ``bipartite_loader`` is present and truthy, the
            data is first passed through ``bipartite_reindexer``. If
            ``schedule_batching`` is present and truthy, the training split
            uses ``ScheduledLoader`` instead of ``TemporalDataLoader``. The
            whole mapping is forwarded as keyword arguments to every loader.

    Returns:
        ``(data, loaders)``, where *loaders* is a dict with the keys
        ``"train_loader"``, ``"val_loader"`` and ``"test_loader"``.
    """
    def _enabled(key):
        # A flag counts only when it is present AND truthy; absent means off.
        return key in params.keys() and params[key]

    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
    # which may refuse to unpickle a custom dataset object. Confirm against
    # the pinned torch version.
    dataset_path = os.path.join(root_dir, "data", dataset_name + ".pt")
    data = torch.load(dataset_path)[0]
    if _enabled("bipartite_loader"):
        data = bipartite_reindexer(data)

    splits = data.train_val_test_split(
        val_ratio=params["val_ratio"],
        test_ratio=params["test_ratio"],
    )
    train_data, val_data, test_data = splits

    # Only the training split can switch loader type; val/test always use
    # the temporal loader.
    train_loader_cls = ScheduledLoader if _enabled("schedule_batching") else TemporalDataLoader
    loaders = {
        "train_loader": train_loader_cls(train_data, **params),
        "val_loader": TemporalDataLoader(val_data, **params),
        "test_loader": TemporalDataLoader(test_data, **params),
    }
    return data, loaders


def make_model(model_name, data, device, params):
    """Instantiate the model called *model_name* for the given dataset.

    Args:
        model_name: one of ``"MLP"``, ``"TGN"``, ``"GraphMixer"``,
            ``"ProfileBuilder"`` or ``"GraphProfiler"`` (case-sensitive).
        data: dataset object. The size of the last dimension of ``data.msg``
            is passed as the message dimension. Depending on the model,
            ``data.num_nodes`` or ``data.src`` / ``data.dst`` are also read.
            TGN and ProfileBuilder additionally receive *data* as ``storage=``.
        device: forwarded to the model constructor as ``device=``.
        params: extra keyword arguments forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if *model_name* is not one of the names above.
    """
    supported = ("MLP", "TGN", "GraphMixer", "ProfileBuilder", "GraphProfiler")

    if model_name == "MLP":
        msg_dim = data.msg.size(-1)
        model = MLP(msg_dim, device=device, **params)
    elif model_name == "TGN":
        num_nodes = data.num_nodes
        msg_dim = data.msg.size(-1)
        model = TGN(num_nodes, msg_dim, device=device, storage=data, **params)
    elif model_name == "GraphMixer":
        num_nodes = data.num_nodes
        msg_dim = data.msg.size(-1)
        model = GraphMixer(num_nodes, msg_dim, device=device, **params)
    elif model_name == "ProfileBuilder":
        # NOTE(review): these are counts of DISTINCT ids, not max id + 1.
        # If the model indexes embeddings by raw id, ids must be contiguous
        # from 0 (presumably ensured by bipartite_reindexer). TODO confirm.
        num_src = data.src.unique().size(0)
        num_dst = data.dst.unique().size(0)
        msg_dim = data.msg.size(-1)
        model = ProfileBuilder(num_src, num_dst, msg_dim, device=device, storage=data, **params)
    elif model_name == "GraphProfiler":
        # Same distinct-id counting caveat as ProfileBuilder above.
        num_src = data.src.unique().size(0)
        num_dst = data.dst.unique().size(0)
        msg_dim = data.msg.size(-1)
        model = GraphProfiler(num_src, num_dst, msg_dim, device=device, **params)
    else:
        # Keep the original exception type so existing callers still catch
        # it, but say which name was rejected and which names are accepted.
        raise NotImplementedError(
            f"Unknown model_name {model_name!r}; expected one of {', '.join(supported)}"
        )
    return model


def make_optimizer(model, params):
    """Return an Adam optimizer over every parameter of *model*.

    Args:
        model: module whose ``parameters()`` are optimized.
        params: mapping that must provide ``"lr"`` and ``"weight_decay"``.

    Returns:
        A configured ``torch.optim.Adam`` instance.
    """
    trainable = model.parameters()
    learning_rate = params["lr"]
    decay = params["weight_decay"]
    return torch.optim.Adam(trainable, lr=learning_rate, weight_decay=decay)


def make_reporter(experiment_dir, cfg):
    """Create the run's ``SimpleReporter``.

    Args:
        experiment_dir: directory handed to the reporter.
        cfg: run configuration handed to the reporter.

    Returns:
        The new ``SimpleReporter``.
    """
    return SimpleReporter(cfg=cfg, experiment_dir=experiment_dir)


def make_dir(root_dir, experiment_name, extra=None):
    """Create and return a new, uniquely named results directory for a run.

    Layout:
        extra is None  -> ``<root_dir>/results/<experiment_name>/runners/<stamp>``
        otherwise      -> ``<root_dir>/results/<experiment_name>/<stamp>``

    where ``<stamp>`` is ``YYYY_MM_DD_HH_MM_SS_rid_<digits>``. It combines the
    local time with a random id drawn from the global ``random`` module.

    Args:
        root_dir: project root that contains (or will contain) ``results``.
        experiment_name: subdirectory name under ``results``.
        extra: only whether it is ``None`` matters; its value is not used.

    Returns:
        Path of the directory, which exists when this function returns.
    """
    base_dir = os.path.join(root_dir, "results", experiment_name)
    if extra is None:
        base_dir = os.path.join(base_dir, "runners")

    timestamp = time.strftime('%Y_%m_%d_%H_%M_%S')
    # Fixed-point formatting: str(random.random()) switches to scientific
    # notation for tiny values (e.g. '5e-05'). That string has no '.', so
    # split('.')[1] raised IndexError, and other small values leaked 'e-05'
    # into the id. Exactly one random.random() draw is still consumed, so
    # the seeded global RNG stream advances the same way as before.
    rid = format(random.random(), ".16f").split(".")[1]
    experiment_dir = os.path.join(base_dir, timestamp + '_rid_' + rid)

    # exist_ok=True replaces the racy exists()-then-makedirs() pair.
    os.makedirs(experiment_dir, exist_ok=True)
    return experiment_dir
    
    
def make_experiment(model, optimizer, reporter, device, experiment_dir, params):
    """Assemble an ``RLCExperiment`` from the already-built components.

    Args:
        model: model to train and evaluate.
        optimizer: optimizer for *model*.
        reporter: reporter the experiment logs through.
        device: device passed to the experiment.
        experiment_dir: output directory for the run.
        params: extra keyword arguments forwarded to ``RLCExperiment``.

    Returns:
        The constructed ``RLCExperiment``.
    """
    components = dict(
        model=model,
        optimizer=optimizer,
        reporter=reporter,
        device=device,
        experiment_dir=experiment_dir,
    )
    return RLCExperiment(**components, **params)


def set_seed(seed, multiple_device=False):
    """Seed every RNG the project uses and force deterministic cuDNN.

    Args:
        seed: seed applied to ``random``, ``numpy`` and ``torch``.
        multiple_device: when true, also call
            ``torch.cuda.manual_seed_all(seed)``.

    Side effects:
        Sets ``torch.backends.cudnn.deterministic = True`` and
        ``torch.backends.cudnn.benchmark = False``, and exports
        ``PYTHONHASHSEED``. NOTE(review): that environment variable only
        affects Python interpreters started after this call, not the
        current process's string hashing.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if multiple_device:
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


def set_device(device_idx=None):
    """Pick the torch device to run on.

    Args:
        device_idx: CUDA device index, or ``None`` for automatic selection.

    Returns:
        ``torch.device('cuda:<device_idx>')`` when an index is given (CUDA
        availability is not checked here). Otherwise ``'cuda'`` if CUDA is
        available, else ``'cpu'``.
    """
    if device_idx is not None:
        return torch.device(f'cuda:{device_idx}')
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)








